import copy
import functools
import io
import itertools
import json
import operator
import secrets
import tempfile
import uuid
from base64 import b64decode, b64encode

from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist, ValidationError
from django.db import models, transaction
from django.db.models.functions import Coalesce
from django.utils import timezone
from django.utils.crypto import get_random_string

from sysreptor import signals as sysreptor_signals
from sysreptor.pentests import cvss
from sysreptor.pentests.fielddefinition.predefined_fields import FINDING_FIELDS_CORE, FINDING_FIELDS_PREDEFINED
from sysreptor.users.models import PentestUser
from sysreptor.utils import crypto
from sysreptor.utils.configuration import configuration
from sysreptor.utils.crypto import pgp
from sysreptor.utils.crypto.secret_sharing import ShamirLarge
from sysreptor.utils.crypto.storage import EncryptedFileAdapter, IterableToFileAdapter
from sysreptor.utils.fielddefinition.types import FieldDefinition, parse_field_definition
from sysreptor.utils.files import normalize_filename
from sysreptor.utils.history import (
    bulk_create_with_history,
    bulk_delete_with_history,
    bulk_update_with_history,
    history_context,
)
from sysreptor.utils.models import SubqueryCount
from sysreptor.utils.utils import get_key_or_attr, groupby_to_dict, omit_keys, set_key_or_attr


class TagQuerysetMixin:
    """Mixin for querysets of models that store tags in a PostgreSQL array field."""

    def get_all_tags(self):
        """Return a flat, sorted, de-duplicated queryset of all tag values."""
        # unnest() expands the tags array into one row per individual tag value.
        unnested = self.annotate(tag=models.Func(models.F('tags'), function='unnest'))
        return unnested.values_list('tag', flat=True).order_by('tag').distinct()


class ProjectTypeQueryset(TagQuerysetMixin, models.QuerySet):
    def only_permitted(self, user):
        """Restrict to designs the given user is allowed to access."""
        if not user or user.is_anonymous or user.is_system_user:
            return self.none()
        if user.is_admin:
            return self
        # Global designs (not linked to a project or user) are visible to everyone.
        permitted = models.Q(models.Q(linked_project=None) & models.Q(linked_user=None))
        if user.is_project_admin:
            permitted |= models.Q(linked_project__isnull=False)
        else:
            permitted |= models.Q(linked_project__members__user=user)
        if configuration.ENABLE_PRIVATE_DESIGNS:
            permitted |= models.Q(linked_user=user)
        return self.filter(permitted)

    def only_global(self):
        """Restrict to globally available designs (no linked project or user)."""
        return self.filter(linked_project=None).filter(linked_user=None)

    def annotate_scope_order(self):
        """Annotate a sort key: project-linked designs first, then global, then private."""
        scope_order = models.Case(
            models.When(models.Q(linked_project__isnull=False), then=1),
            models.When(models.Q(linked_project=None) & models.Q(linked_user=None), then=2),
            default=3,
        )
        return self.annotate(scope_order=scope_order)

    def annotate_status_order(self):
        """Annotate a sort key by review status (finished first, deprecated last)."""
        from sysreptor.pentests.models import ReviewStatus

        # Custom statuses defined via configuration sort between the built-in ones.
        custom_status_whens = [
            models.When(status=d['id'], then=10 + idx)
            for idx, d in enumerate(ReviewStatus.get_definitions())
        ]
        status_order = models.Case(
            models.When(status=ReviewStatus.FINISHED.value, then=1),
            models.When(status=ReviewStatus.IN_PROGRESS.value, then=2),
            *custom_status_whens,
            models.When(status=ReviewStatus.DEPRECATED.value, then=1000),
            default=999,
        )
        return self.annotate(status_order=status_order)

    def custom_finding_field_definitions(self):
        """
        Return all custom field definitions over all globally visible ProjectTypes.
        Handle conflicting data types of custom fields by using the field of the first ProjectType
        e.g. ProjectType1 defines custom_field: string; ProjectType2 defines custom_field: list[string] => use custom_field: string
        """
        field_lists = self \
            .only_global() \
            .order_by('created', 'id') \
            .values_list('finding_fields', flat=True)

        # First occurrence of a field id wins on conflicting definitions.
        merged = {}
        for field in itertools.chain.from_iterable(field_lists):
            merged.setdefault(field['id'], field)
        return parse_field_definition(merged.values())

    def increment_usage_count(self, by=1):
        """Atomically increase usage_count of all designs in this queryset by ``by``."""
        return self.update(usage_count=models.F('usage_count') + models.Value(by))


class ProjectTypeManager(models.Manager.from_queryset(ProjectTypeQueryset)):
    use_in_migrations = True

    @transaction.atomic()
    @history_context(history_change_reason='Duplicated', prevent_notifications=True)
    def copy(self, instance, **kwargs):
        """Duplicate a ProjectType including its uploaded assets; returns the new instance."""
        from sysreptor.pentests.models import UploadedAsset

        original_assets = list(instance.assets.all())

        # Shallow-copy the design. Remember the source design via copy_of_id
        # before the pk is cleared, then save as a fresh row.
        instance = copy.copy(instance)
        instance.created = timezone.now()
        instance.usage_count = 0
        for field, value in (kwargs or {}).items():
            setattr(instance, field, value)
        instance.copy_of_id = instance.pk
        instance.pk = None
        instance.lock_info_data = None
        instance.save()

        # Re-attach copies of all assets to the new design.
        for asset in original_assets:
            asset.pk = None
            asset.linked_object = instance
        bulk_create_with_history(UploadedAsset, original_assets)

        # Send signal
        sysreptor_signals.post_create.send(sender=instance.__class__, instance=instance)

        return instance


class PentestProjectQueryset(TagQuerysetMixin, models.QuerySet):
    def only_permitted(self, user):
        """Restrict to projects the given user may access (admins see everything)."""
        is_real_user = user and not user.is_anonymous and not user.is_system_user
        if not is_real_user:
            return self.none()
        if user.is_admin or user.is_project_admin:
            return self
        # Regular users only see projects they are a member of.
        return self.filter(members__user=user)

    def only_archivable(self):
        """Restrict to projects that currently have enough eligible archive users."""
        from sysreptor.pentests.models import ArchivedProject
        return ArchivedProject.objects.filter_projects_can_be_archived(self)


class PentestProjectManager(models.Manager.from_queryset(PentestProjectQueryset)):
    @transaction.atomic()
    def copy(self, instance, **kwargs):
        """
        Deep-copy a project: design (as snapshot), members, sections, findings,
        notes, excalidraw files, images, files and comments.
        Extra ``kwargs`` are set as attributes on the copy. Returns the new project.
        """
        from sysreptor.pentests.models import (
            Comment,
            CommentAnswer,
            PentestFinding,
            ProjectNotebookExcalidrawFile,
            ProjectNotebookPage,
            ReportSection,
            SourceEnum,
            UploadedImage,
            UploadedProjectFile,
        )
        with history_context(history_change_reason=f'Duplicated from project "{instance.name}"', prevent_notifications=True):
            # Fetch all related objects before the instance is mutated below.
            findings = list(instance.findings.all())
            sections = list(instance.sections.all())
            notes = list(instance.notes.select_related('parent').all())
            excalidraw = list(ProjectNotebookExcalidrawFile.objects.filter(linked_object__project=instance).select_related('linked_object'))
            members = list(instance.members.all())
            images = list(instance.images.all())
            files = list(instance.files.all())
            comments = list(Comment.objects.filter_project(instance).select_related('section', 'finding').prefetch_related('answers').all())

            # Copy project
            instance = copy.copy(instance)
            instance.created = timezone.now()
            for k, v in (kwargs or {}).items():
                setattr(instance, k, v)
            # Remember the source project before clearing the pk.
            instance.copy_of_id = instance.pk
            instance.pk = None
            instance.readonly = False
            # Snapshot the design, unless it is an imported dependency or already customized.
            instance.project_type = instance.project_type.copy(
                linked_user=None,
                source=SourceEnum.SNAPSHOT if instance.project_type.source not in [SourceEnum.IMPORTED_DEPENDENCY, SourceEnum.CUSTOMIZED] else instance.project_type.source,
                usage_count=1,
            )
            instance.skip_post_create_signal = True
            instance.save()
            # Link the copied design back to the new project.
            instance.project_type.linked_project = instance
            instance.project_type.save(update_fields=['linked_project'])

            for mi in members:
                mi.pk = None
                mi.project = instance
            # NOTE(review): presumably a model method delegating to this manager's
            # set_members(instance, members, new=True) — confirm on PentestProject.
            instance.set_members(members, new=True)

            # Copy sections
            # Remove sections already present on the new project (presumably
            # auto-created during save — TODO confirm) so the originals can be inserted.
            ReportSection.objects.filter(project=instance).delete()
            for s in sections:
                s.pk = None
                s.project = instance
            # Also drop history rows of those deleted sections that share a section_id with the copies.
            ReportSection.history.filter(project_id=instance.id).filter(section_id__in=[s.section_id for s in sections]).delete()
            bulk_create_with_history(ReportSection, sections)

            # Copy findings
            for f in findings:
                f.pk = None
                f.project = instance
            bulk_create_with_history(PentestFinding, findings)

            # Copy notes
            # note_id stays stable; only the pk changes, so parents can be re-matched below.
            for n in notes:
                n.pk = uuid.uuid4()
                n.project = instance
            # Update parent to copied model
            for n in notes:
                if n.parent:
                    n.parent = next(filter(lambda pn: pn.note_id == n.parent.note_id, notes), None)
            bulk_create_with_history(ProjectNotebookPage, notes)

            # Copy excalidraw
            for e in excalidraw:
                e.pk = None
                e.linked_object = next(filter(lambda n: n.note_id == e.linked_object.note_id, notes))
            bulk_create_with_history(ProjectNotebookExcalidrawFile, excalidraw)

            # Copy images
            for i in images:
                i.pk = None
                i.linked_object = instance
            bulk_create_with_history(UploadedImage, images)

            # Copy files
            for f in files:
                f.pk = None
                f.linked_object = instance
            bulk_create_with_history(UploadedProjectFile, files)

            # Copy comments
            # Re-attach each comment to the copied section/finding via the stable sub-ids.
            comment_answers = []
            for c in comments:
                answers = list(c.answers.all())
                c.pk = None
                c.section = next(filter(lambda s: c.section and c.section.section_id == s.section_id, sections), None)
                c.finding = next(filter(lambda f: c.finding and c.finding.finding_id == f.finding_id, findings), None)

                for ca in answers:
                    ca.pk = None
                    ca.comment = c
                    comment_answers.append(ca)
            bulk_create_with_history(Comment, comments)
            bulk_create_with_history(CommentAnswer, comment_answers)

            # Send signal
            sysreptor_signals.post_create.send(sender=instance.__class__, instance=instance)

            return instance

    @transaction.atomic
    @history_context()
    def set_members(self, instance, members, new=False):
        """
        Replace the member list of ``instance`` with ``members``: create missing
        members, delete removed ones and update changed roles.
        ``new=True`` skips loading existing members (for freshly created projects).
        ``members=None`` is a no-op.
        """
        from sysreptor.pentests.models import ProjectMemberInfo

        if members is None:
            return

        for m in members:
            m.pk = None
            m.project = instance
            m.roles = list(set(m.roles))  # de-duplicate roles

        # Diff incoming vs. existing members by user_id.
        members_map = dict(map(lambda m: (m.user_id, m), members))
        existing_members_map = {} if new else dict(map(lambda m: (m.user_id, m), instance.members.all()))

        if new_members := omit_keys(members_map, existing_members_map.keys()).values():
            bulk_create_with_history(ProjectMemberInfo, new_members, send_signals=True)
        if removed_members := omit_keys(existing_members_map, members_map.keys()).values():
            bulk_delete_with_history(ProjectMemberInfo, removed_members)

        # Update roles of members present in both maps whose role set changed.
        updated_members = []
        for k, m in existing_members_map.items():
            if k in members_map and set(m.roles) != set(members_map[k].roles):
                m.roles = members_map[k].roles
                m.updated = timezone.now()
                updated_members.append(m)
        if updated_members:
            bulk_update_with_history(ProjectMemberInfo, updated_members, fields=['roles', 'updated'], send_signals=True)

    @history_context()
    def add_member(self, user, projects):
        """Add ``user`` as a member to every project in ``projects`` where not already a member."""
        from sysreptor.pentests.models import ProjectMemberInfo

        existing_members = set(ProjectMemberInfo.objects \
            .filter(project__in=projects) \
            .filter(user=user) \
            .values_list('project_id', flat=True))
        new_members = [ProjectMemberInfo(user=user, project=p) for p in projects if p.id not in existing_members]
        bulk_create_with_history(ProjectMemberInfo, new_members)


class PentestFindingManager(models.Manager.from_queryset(models.QuerySet)):
    def create(self, project=None, data=None, order=None, **kwargs):
        """Create a finding; without an explicit order it is appended after the project's highest order."""
        from sysreptor.pentests.models import PentestFinding

        if project and not order:
            max_order = self \
                .filter(project=project) \
                .aggregate(max_order=Coalesce(models.Max('order'), models.Value(0)))['max_order']
            order = max_order + 1
        finding = PentestFinding(project=project, order=order, **kwargs)
        if data is not None:
            finding.update_data(data)
        finding.save()
        return finding

    def update_order(self, instances, missing_instances=None):
        """Re-number order fields to 1..n; findings in missing_instances keep their relative order but move to the end."""
        from sysreptor.pentests.models import PentestFinding

        missing_instances = missing_instances or []
        # Stable sort: primary key is membership in missing_instances, secondary is the old order.
        ordered = sorted(instances, key=lambda f: (f in missing_instances, f.order))
        for position, finding in enumerate(ordered, start=1):
            finding.order = position
            finding.updated = timezone.now()
        bulk_update_with_history(PentestFinding, instances, fields=['order', 'updated'])


class ReportSectionManager(models.Manager.from_queryset(models.QuerySet)):
    # Default manager for ReportSection; no custom queryset or manager behavior.
    pass


class CommentQueryset(models.QuerySet):
    def filter_project(self, project):
        """Restrict to comments attached to a section or a finding of the given project."""
        belongs_to_project = models.Q(section__project=project) | models.Q(finding__project=project)
        return self.filter(belongs_to_project)


class CommentManager(models.Manager.from_queryset(CommentQueryset)):
    # Exposes CommentQueryset helpers (e.g. filter_project) on Comment.objects.
    pass


class FindingTemplateQueryset(TagQuerysetMixin, models.QuerySet):
    def increment_usage_count(self, by=1):
        """Atomically increase usage_count of all templates in this queryset by ``by``."""
        return self.update(usage_count=models.F('usage_count') + models.Value(by))

    def get_field_definition(self):
        """
        Return the combined finding field definition for templates: core and
        predefined fields first, then custom fields from all global designs
        (ids already present in the predefined set are skipped).
        Each field's extra_info['used_in_designs'] reflects whether any design defines it.
        """
        from sysreptor.pentests.models import ProjectType

        custom_fields = ProjectType.objects.custom_finding_field_definitions()
        definition = FINDING_FIELDS_CORE | FINDING_FIELDS_PREDEFINED
        definition |= FieldDefinition(fields=[f for f in custom_fields.fields if f.id not in definition])
        for f in definition.fields:
            f.extra_info |= {
                'used_in_designs': f.id in custom_fields,
            }
        return definition

    def order_by_language(self, language):
        """Order templates that have a translation in ``language`` before those that do not."""
        from sysreptor.pentests.models import FindingTemplateTranslation

        return self \
            .annotate(has_language=models.Exists(FindingTemplateTranslation.objects.filter(language=language).filter(template=models.OuterRef('pk')))) \
            .order_by('-has_language')

    def annotate_risk_level_number(self):
        """Annotate numeric sort keys for the main translation's risk level (5=critical .. 1=other) and risk score."""
        return self \
            .annotate(risk_level_number=models.Case(
                models.When(main_translation__risk_level=cvss.CVSSLevel.CRITICAL.value, then=5),
                models.When(main_translation__risk_level=cvss.CVSSLevel.HIGH.value, then=4),
                models.When(main_translation__risk_level=cvss.CVSSLevel.MEDIUM.value, then=3),
                models.When(main_translation__risk_level=cvss.CVSSLevel.LOW.value, then=2),
                default=1)) \
            .annotate(risk_score_number=Coalesce(models.F('main_translation__risk_score'), 0.0))

    def annotate_search_rank(self, search_terms: list[str], sub_ranks=False):
        """
        Annotate a search_rank score summing per-term matches: tag match 1.0,
        title match 1.0, custom field data match 0.2.
        With sub_ranks=True the per-term annotations are kept (the join over
        translations may yield multiple rows per template); otherwise the
        maximum rank per template is collapsed into one row via a subquery.
        Raises TypeError for empty search_terms (reduce over an empty sequence).
        """
        qs = self
        for idx, term in enumerate(search_terms):
            qs = qs.annotate(**{
                f'search_term_{idx}_matches_tags': models.Case(models.When(tags__icontains=term, then=1.0), default=0.0),
                f'search_term_{idx}_matches_title': models.Case(models.When(translations__title__icontains=term, then=1.0), default=0.0),
                f'search_term_{idx}_matches_data': models.Case(models.When(translations__custom_fields__icontains=term, then=0.2), default=0.0),
                f'search_term_{idx}_rank': models.F(f'search_term_{idx}_matches_tags') + models.F(f'search_term_{idx}_matches_title') + models.F(f'search_term_{idx}_matches_data'),
            })
        qs = qs \
            .annotate(search_rank=functools.reduce(operator.add, [models.F(f'search_term_{idx}_rank') for idx in range(len(search_terms))]))
        if sub_ranks:
            return qs
        else:
            return self.annotate(search_rank=models.Subquery(
                qs.filter(id=models.OuterRef('id'))
                .annotate(search_rank_max=models.Max('search_rank'))
                .values('search_rank_max')))

    def search(self, search_terms: list[str]):
        """
        Filter to templates matching all search terms and order by search_rank descending.
        An order_by_language() ordering applied to this queryset is kept as the primary key.
        """
        qs = self.annotate_search_rank(search_terms, sub_ranks=True)
        # Every term must match at least one of tags/title/data.
        for idx in range(len(search_terms)):
            qs = qs.filter(**{f'search_term_{idx}_rank__gt': 0})
        qs = qs.values_list('id', flat=True)

        order_by = ('-search_rank',)
        if qs.query.order_by == ('-has_language',):
            order_by = qs.query.order_by + order_by
        return self \
            .filter(id__in=qs) \
            .annotate_search_rank(search_terms) \
            .order_by(*order_by)


class FindingTemplateManager(models.Manager.from_queryset(FindingTemplateQueryset)):
    @transaction.atomic()
    @history_context(history_change_reason='Duplicated', prevent_notifications=True)
    def copy(self, instance, **kwargs):
        """Duplicate a finding template including all translations and images; returns the copy."""
        from sysreptor.pentests.models import FindingTemplateTranslation, UploadedTemplateImage

        translations = list(instance.translations.all())
        images = list(instance.images.all())

        # Copy model
        instance = copy.copy(instance)
        instance.created = timezone.now()
        instance.usage_count = 0
        for k, v in (kwargs or {}).items():
            setattr(instance, k, v)
        instance.copy_of_id = instance.pk
        instance.pk = None
        # main_translation can only point at a copied translation after those rows
        # exist, so the template is saved in two phases (first without history).
        instance.main_translation = None
        instance.lock_info_data = None
        instance.skip_post_create_signal = True
        instance.save_without_historical_record()

        # Copy translations
        for t in translations:
            # Capture is_main before clearing the pk (presumably derived from ids — TODO confirm).
            is_main = t.is_main
            t.pk = None
            t.template = instance
            if is_main:
                instance.main_translation = t
        bulk_create_with_history(FindingTemplateTranslation, translations)

        # Set main translation
        # NOTE(review): '_history_type' = '+' presumably makes the second save record
        # a creation entry instead of a change — confirm with django-simple-history.
        instance._history_type = '+'
        instance.save()
        del instance._history_type

        # Copy images
        for i in images:
            i.pk = None
            i.linked_object = instance
        bulk_create_with_history(UploadedTemplateImage, images)

        # Send signal
        sysreptor_signals.post_create.send(sender=instance.__class__, instance=instance)

        return instance


class FindingTemplateTranslationQueryset(models.QuerySet):
    def default_order(self):
        """Order translations with the template's main translation first, then by creation date."""
        is_main = models.Q(id=models.F('template__main_translation_id'))
        return self.annotate(is_main_order=is_main).order_by('-is_main_order', 'created')


class NotebookPageManagerBase:
    """Shared note logic for project and personal notebooks (create, copy, tree handling)."""

    def create(self, project=None, user=None, order=None, parent=None, **kwargs):
        """Create a note; without an explicit order it is appended after the highest sibling order."""
        if not order and (project or user):
            if project:
                order_qs = self.filter(project=project)
            elif user:
                order_qs = self.filter(user=user)
            # Order is scoped per parent: new notes go to the end of their sibling list.
            order = order_qs \
                .filter(parent=parent) \
                .aggregate(max_order=Coalesce(models.Max('order'), models.Value(0)))['max_order'] + 1

        if project:
            kwargs['project'] = project
        if user:
            kwargs['user'] = user

        return super().create(parent=parent, order=order, **kwargs)

    @transaction.atomic()
    @history_context(history_change_reason='Duplicated', prevent_notifications=True)
    def copy(self, instance, **kwargs):
        """
        Duplicate a note including all child notes and attached excalidraw files.
        The copy is inserted directly below the original; following siblings are
        shifted down. Returns the copied root note.
        """
        from sysreptor.pentests.models import NoteType, ProjectNotebookPage

        # Load the whole notebook once; the child hierarchy is resolved in memory.
        if isinstance(instance, ProjectNotebookPage):
            original_notes = list(instance.project.notes.prefetch_related('excalidraw_file').all())
        else:
            original_notes = list(instance.user.notes.prefetch_related('excalidraw_file').all())

        notes_to_create = []
        notes_to_update = []
        excalidraw_to_create = []

        def copy_note_model(note, **kwargs):
            # Recursively copy a note and its subtree, collecting copies in notes_to_create.
            # Copy model
            original_id = note.id
            note = copy.copy(note)
            note.created = timezone.now()
            note.updated = timezone.now()
            note.note_id = uuid.uuid4()
            for k, v in (kwargs or {}).items():
                setattr(note, k, v)
            note.pk = None
            notes_to_create.append(note)

            # Copy excalidraw file
            if note.type == NoteType.EXCALIDRAW:
                try:
                    note.excalidraw_file.pk = None
                    note.excalidraw_file.linked_object = note
                    excalidraw_to_create.append(note.excalidraw_file)
                except ObjectDoesNotExist:
                    # Excalidraw note without an attached file
                    pass

            # Copy children
            for n in original_notes:
                if n.parent_id == original_id:
                    copy_note_model(n, parent=note)

            return note

        # Copy model
        instance = copy_note_model(instance, **kwargs)

        # Update order: insert below original note
        instance.order += 1
        for n in original_notes:
            if n.parent_id == instance.parent_id and n.order >= instance.order:
                n.order += 1
                n.updated = timezone.now()
                notes_to_update.append(n)

        bulk_update_with_history(self.model, notes_to_update, fields=['order', 'updated'])
        bulk_create_with_history(self.model, notes_to_create, send_signals=True)
        if excalidraw_to_create:
            bulk_create_with_history(excalidraw_to_create[0]._meta.model, excalidraw_to_create)

        return instance

    def check_parent_and_order(self, instances, missing_instances=None):
        """Normalize order values per tree layer and validate parent relationships."""
        # * Update order values: first all notes in data, then missing notes (keep order of missing notes, but move to end)
        # * and validate no circular dependencies: beginning from the tree root, every note must be in the tree.
        #   If it does not have a path from root to node, there is a circular dependency.
        missing_instances = missing_instances or []
        def process_tree_layer(layer):
            for idx, n in enumerate(sorted(layer, key=lambda n: (n['note'] in missing_instances, get_key_or_attr(n['note'], 'order')))):
                set_key_or_attr(n['note'], 'order', idx + 1)
                process_tree_layer(n['children'])
        process_tree_layer(self.to_tree(instances))

    def to_tree(self, instances):
        """
        Build a nested tree of {'note', 'children'} dicts from a flat list of note
        instances or note-dicts. Raises ValidationError when some notes are not
        reachable from the root (circular parent relationships).
        """
        parent_dict = groupby_to_dict(instances, key=lambda n: str((n.get('parent') if isinstance(n, dict) else n.parent_id) or ''))
        in_tree = set()
        def format_tree_layer(parent_id):
            out = []
            for n in sorted(parent_dict.get(str(parent_id or ''), []), key=lambda n: get_key_or_attr(n, 'order') or 0):
                note_id = get_key_or_attr(n, 'id')
                in_tree.add(note_id)
                out.append({
                    'note': n,
                    'children': format_tree_layer(note_id),
                })
            return out
        tree = format_tree_layer(None)
        if len(in_tree) != len(instances):
            raise ValidationError('Circular parent relationships detected')
        return tree

    def to_ordered_list_flat(self, instances):
        """Return the notes of to_tree(instances) flattened in depth-first order."""
        def flatten_note_tree(layer):
            out = []
            for n in layer:
                out.append(n['note'])
                out.extend(flatten_note_tree(n.get('children', [])))
            return out
        return flatten_note_tree(self.to_tree(instances))

class ProjectNotebookPageQuerySet(models.QuerySet):
    def annotate_is_shared(self):
        """Annotate each note with whether an active share link exists for it."""
        from sysreptor.pentests.models import ShareInfo

        active_shares = ShareInfo.objects \
            .filter(note=models.OuterRef('pk')) \
            .only_active()
        return self.annotate(is_shared=models.Exists(active_shares))

    def only_shared(self):
        """Restrict to notes that currently have at least one active share link."""
        return self.annotate_is_shared().filter(is_shared=True)


class ProjectNotebookPageManager(NotebookPageManagerBase, models.Manager.from_queryset(ProjectNotebookPageQuerySet)):
    def child_notes_of(self, note):
        """Return a queryset containing the given note plus all of its (transitive) children."""
        all_notes = note.project.notes.all()
        children_by_parent = groupby_to_dict(all_notes, key=lambda n: str(n.parent_id) if n.parent_id else '')

        # Iterative depth-first collection of the subtree rooted at note (note included).
        collected_ids = []
        stack = [note]
        while stack:
            current = stack.pop()
            collected_ids.append(current.id)
            stack.extend(children_by_parent.get(str(current.id), []))

        return self.filter(id__in=collected_ids)


class UserNotebookPageManager(NotebookPageManagerBase, models.Manager.from_queryset(models.QuerySet)):
    # Personal (per-user) notebook notes; all behavior comes from NotebookPageManagerBase.
    pass


class ShareInfoQuerySet(models.QuerySet):
    def only_active(self):
        """Restrict to shares that are neither revoked nor expired; empty when sharing is disabled."""
        if configuration.DISABLE_SHARING:
            return self.none()
        not_expired = models.Q(expire_date__gte=timezone.now().date())
        return self.filter(not_expired).filter(is_revoked=False)

    def increment_failed_password_attempts(self):
        """Atomically bump the failed password attempt counter of all shares in this queryset."""
        return self.update(failed_password_attempts=models.F('failed_password_attempts') + 1)


class ShareInfoManager(models.Manager.from_queryset(ShareInfoQuerySet)):
    # Exposes ShareInfoQuerySet helpers (e.g. only_active) on ShareInfo.objects.
    pass


class UserPublicKeyQuerySet(models.QuerySet):
    def only_enabled(self):
        """Restrict to public keys marked as enabled."""
        return self.filter(enabled=True)


class UserPublicKeyManager(models.Manager.from_queryset(UserPublicKeyQuerySet)):
    def create(self, public_key=None, public_key_info=None, **kwargs):
        """Create a public key; when no info is given, derive it from the key material."""
        if public_key and not public_key_info:
            public_key_info = pgp.public_key_info(public_key)
        return super().create(public_key=public_key, public_key_info=public_key_info, **kwargs)


class ArchivedProjectQuerySet(TagQuerysetMixin, models.QuerySet):
    def only_permitted(self, user):
        """Restrict to archives the given user may access (admins see everything)."""
        is_real_user = user and not user.is_anonymous and not user.is_system_user
        if not is_real_user:
            return self.none()
        if user.is_admin or user.is_project_admin:
            return self
        # Regular users only see archives they hold a key part for.
        return self.filter(key_parts__user=user)


class ArchivedProjectManager(models.Manager.from_queryset(ArchivedProjectQuerySet)):
    def get_possible_archive_users_for_project(self, project, always_include_members=False):
        """Return all users that could hold a key part when archiving the given project."""
        from sysreptor.pentests.models import UserPublicKey

        can_archive = models.Q(is_global_archiver=True)
        if configuration.PROJECT_MEMBERS_CAN_ARCHIVE_PROJECTS or always_include_members:
            can_archive |= models.Q(pk__in=project.members.values_list('user_id'))
        return PentestUser.objects \
            .filter(can_archive) \
            .prefetch_related(models.Prefetch('public_keys', UserPublicKey.objects.only_enabled()))

    def get_archive_users_for_project(self, project):
        """Return users that can actually archive: possible archivers that are active and have public keys."""
        return self.get_possible_archive_users_for_project(project) \
            .only_active() \
            .only_with_public_keys()

    def filter_projects_can_be_archived(self, projects):
        """Restrict ``projects`` to those with at least ARCHIVING_THRESHOLD eligible archive users."""
        can_archive = models.Q(is_global_archiver=True)
        if configuration.PROJECT_MEMBERS_CAN_ARCHIVE_PROJECTS:
            can_archive |= models.Q(projectmemberinfo__project=models.OuterRef('pk'))

        archive_user_count = PentestUser.objects \
            .filter(can_archive) \
            .only_active() \
            .only_with_public_keys()

        return projects \
            .annotate(archive_user_count=SubqueryCount(archive_user_count)) \
            .filter(archive_user_count__gte=int(configuration.ARCHIVING_THRESHOLD))

    @transaction.atomic()
    def create_from_project(self, project, name=None, users=None, delete_project=True):
        """
        Archive a project: export it, encrypt the export with a random AES-256 key,
        split that key with Shamir secret sharing among ``users``, and encrypt each
        user's key part with the user's public keys. Optionally deletes the project.

        Raises ValueError when fewer users than ARCHIVING_THRESHOLD are available
        or a user has no enabled public key.
        """
        from sysreptor.pentests.models import (
            ArchivedProject,
            ArchivedProjectKeyPart,
            ArchivedProjectPublicKeyEncryptedKeyPart,
        )

        name = name or project.name
        users = list(users or self.get_archive_users_for_project(project))
        if len(users) < int(configuration.ARCHIVING_THRESHOLD):
            raise ValueError('Too few users')

        archive = ArchivedProject(
            name=name,
            tags=project.tags,
            threshold=int(configuration.ARCHIVING_THRESHOLD),
        )
        key_parts_to_create = []
        encrypted_key_parts_to_create = []

        # Create a random AES-256 key for encrypting the whole archive
        aes_key = secrets.token_bytes(32)
        # Split the AES key using shamir secret sharing and distribute key parts to users
        shamir_key_parts = ShamirLarge.split_large(k=archive.threshold, n=len(users), secret=aes_key)
        for user, (shamir_key_id, shamir_key) in zip(users, shamir_key_parts, strict=False):
            # Encrypt the per-user shamir key with a per-user AES key.
            # This is mainly used for integrity protection to detect corrupted/user-forged shamir key parts.
            # This additional encryption layer also makes it possible to support public key encryption schemes
            # other than PGP (which uses its own file encryption layer on top of public keys) in the future.
            user_aes_key = secrets.token_bytes(32)
            shamir_key_part_data_io = io.BytesIO()
            with crypto.open(shamir_key_part_data_io, mode='wb', key=crypto.EncryptionKey(id=None, key=user_aes_key)) as c:
                c.write(json.dumps({'key_id': shamir_key_id, 'key': b64encode(shamir_key).decode()}).encode())

            key_part_model = ArchivedProjectKeyPart(
                archived_project=archive,
                user=user,
                encrypted_key_part=shamir_key_part_data_io.getvalue(),
            )
            key_parts_to_create.append(key_part_model)

            # Encrypt the per-user AES key with each user's public key
            user_public_keys = [pk for pk in user.public_keys.all() if pk.enabled]
            if not user_public_keys:
                raise ValueError('User does not have any usable public key')
            for public_key in user_public_keys:
                encrypted_key_parts_to_create.append(ArchivedProjectPublicKeyEncryptedKeyPart(
                    key_part=key_part_model,
                    public_key=public_key,
                    encrypted_data=public_key.encrypt(
                        data=f'Use following key for project "{archive.name}"\n'.encode() + b64encode(user_aes_key) + b'\n'),
                ))

        # export archive and encrypt with AES-256 key and upload to storage
        from sysreptor.pentests.import_export import export_projects
        archive.file = EncryptedFileAdapter(
            file=IterableToFileAdapter(export_projects([project], export_all=True), name=str(uuid.uuid4())),
            key=crypto.EncryptionKey(id=None, key=aes_key),
        )

        # Create models in DB
        archive.save()
        ArchivedProjectKeyPart.objects.bulk_create(key_parts_to_create)
        ArchivedProjectPublicKeyEncryptedKeyPart.objects.bulk_create(encrypted_key_parts_to_create)

        # Send signal
        sysreptor_signals.post_archive.send(sender=project.__class__, instance=project, archive=archive)

        # Delete project
        if delete_project:
            with history_context(prevent_notifications=True):
                project.delete()

        return archive

    @transaction.atomic()
    def restore_project(self, archive):
        """
        Restore a project from an archive: combine decrypted key parts via Shamir
        secret sharing, decrypt and import the archive file, add all archivers as
        members, and delete the archive. Returns the restored project.
        Raises ValueError when fewer decrypted key parts than the threshold exist.
        """
        from sysreptor.pentests.import_export import import_projects
        from sysreptor.pentests.models import PentestProject

        # Combine key parts with shamir secret sharing to decrypt the archive key
        key_parts = list(filter(lambda k: k.is_decrypted, archive.key_parts.all()))
        if len(key_parts) < archive.threshold:
            raise ValueError('Too few key parts available')
        archive_key = ShamirLarge.combine_large([(k.key_part['key_id'], b64decode(k.key_part['key'])) for k in key_parts])

        # Decrypt archive and import project.
        # Spooled file keeps memory bounded: data above FILE_UPLOAD_MAX_MEMORY_SIZE rolls over to disk.
        with tempfile.SpooledTemporaryFile(max_size=settings.FILE_UPLOAD_MAX_MEMORY_SIZE, mode='w+b') as f:
            with crypto.open(archive.file, mode='rb', key=crypto.EncryptionKey(id=None, key=archive_key)) as c:
                while chunk := c.read(settings.FILE_UPLOAD_MAX_MEMORY_SIZE):
                    f.write(chunk)
            f.seek(0)
            projects = import_projects(f)

        # Add archivers as members (only relevant for global archivers)
        for k in archive.key_parts.all():
            PentestProject.objects.add_member(k.user, projects)

        # Delete archive
        archive.delete()
        return projects[0]


class UploadedFileQueryset(models.QuerySet):
    def filter_name(self, name):
        """Filter uploaded files by name via the stored name hash."""
        from sysreptor.pentests.models import UploadedFileBase
        hashed_name = UploadedFileBase.hash_name(name)
        return self.filter(name_hash=hashed_name)


class UploadedFileManagerMixin:
    """Mixin that de-duplicates uploaded file names per linked object."""

    def randomize_name(self, name):
        """Insert a random suffix before the file extension, or append it when there is none."""
        suffix = '-' + get_random_string(8)
        ext_idx = name.rfind('.')
        # ext_idx > 0: a dot exists and is not the first character
        # (names like ".profile" are treated as having no extension).
        if ext_idx > 0:
            return name[:ext_idx] + suffix + name[ext_idx:]
        return name + suffix

    def create(self, file, linked_object, name=None, **kwargs):
        """Create an uploaded file; retries with randomized names until the name is unique for linked_object."""
        base_name = normalize_filename(name or file.name or 'file')
        candidate = base_name
        while self.filter(linked_object=linked_object).filter_name(candidate).exists():
            candidate = self.randomize_name(base_name)

        # Randomize filename in storage to not leak information
        return super().create(file=file, name=candidate, linked_object=linked_object, **kwargs)


class UploadedImageManager(UploadedFileManagerMixin, models.Manager.from_queryset(UploadedFileQueryset)):
    # Name de-duplication comes from UploadedFileManagerMixin; no further customization.
    pass


class UploadedTemplateImageManager(UploadedFileManagerMixin, models.Manager.from_queryset(UploadedFileQueryset)):
    # Name de-duplication comes from UploadedFileManagerMixin; no further customization.
    pass


class UploadedAssetManager(UploadedFileManagerMixin, models.Manager.from_queryset(UploadedFileQueryset)):
    # Name de-duplication comes from UploadedFileManagerMixin; no further customization.
    pass


class UploadedUserNotebookImageManager(UploadedFileManagerMixin, models.Manager.from_queryset(UploadedFileQueryset)):
    # Name de-duplication comes from UploadedFileManagerMixin; no further customization.
    pass


class UploadedUserNotebookFileManager(UploadedFileManagerMixin, models.Manager.from_queryset(UploadedFileQueryset)):
    # Name de-duplication comes from UploadedFileManagerMixin; no further customization.
    pass


class UploadedProjectFileManager(UploadedFileManagerMixin, models.Manager.from_queryset(UploadedFileQueryset)):
    # Name de-duplication comes from UploadedFileManagerMixin; no further customization.
    pass

